{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from rllm.system_prompts import LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE, LCB_SYSTEM_MESSAGE_GENERIC\n",
    "\n",
    "train_dataset = load_dataset(\"KodCode/KodCode-V1\", split=\"train\")\n",
    "print(\"Training set:\", train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter Rules\n",
    "\n",
    "`style`: instruct\n",
    "\n",
    "`subset`: Leetcode, Codeforces, Code Contests, Taco, Apps\n",
    "\n",
    "`GPT4o Pass Count`: < 9\n",
    "\n",
    "`Benchmark Similarity`: < 0.9\n",
    "\n",
    "Test count: >= 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kodcode_filter(x):\n",
    "    if x[\"subset\"] in [\"Leetcode\", \"Codeforces\", \"Code_Contests\", \"Apps\", \"Taco\"]:\n",
    "        if x[\"style\"] == \"instruct\":\n",
    "            if x[\"gpt_pass_trial_num\"] < 9:\n",
    "                if x[\"benchmark_similarity\"] < 0.9:\n",
    "                    if x[\"test_code\"].count(\"def\") >= 8:\n",
    "                        return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def_filtered_dataset = train_dataset.filter(kodcode_filter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bad Data Removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bad_ids = [\"Codeforces_12376_I\"]\n",
    "error_filtered_dataset = def_filtered_dataset.filter(lambda x: x[\"question_id\"] not in bad_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_test_info(entry):\n",
    "    \"\"\"\n",
    "    Format test_info into Python starter code with docstring and function declaration.\n",
    "\n",
    "    Args:\n",
    "        test_info: List of test info dictionaries containing function declaration and docstring\n",
    "\n",
    "    Returns:\n",
    "        Formatted Python starter code string\n",
    "    \"\"\"\n",
    "    # Return empty if no test info\n",
    "    test_info = entry.get(\"test_info\", [])\n",
    "    if not test_info:\n",
    "        return \"\"\n",
    "\n",
    "    # Get the function declaration and docstring from first test info\n",
    "    tests = entry[\"test_code\"]\n",
    "    solution_import = \"\\n\".join([line for line in tests.split(\"\\n\") if line.strip().startswith(\"from solution import\")])\n",
    "\n",
    "    solution_funcs = []\n",
    "    # Get all the solution functions from solution import\n",
    "    if solution_import:\n",
    "        # Extract function names from the import statement\n",
    "        import_parts = solution_import.replace(\"from solution import \", \"\").strip()\n",
    "        # Split by commas and strip whitespace\n",
    "        solution_funcs = [func.strip() for func in import_parts.split(\",\")]\n",
    "\n",
    "    for t in test_info:\n",
    "        if t[\"function_name\"] in tests and t[\"function_name\"] not in solution_funcs:\n",
    "            solution_funcs.append(t[\"function_name\"])\n",
    "    # Check test infos\n",
    "    relevant_test_infos = {}\n",
    "    for t_info in test_info:\n",
    "        func_dec = t_info.get(\"function_name\", \"\")\n",
    "        if not func_dec:\n",
    "            continue\n",
    "\n",
    "        if func_dec in solution_funcs:\n",
    "            relevant_test_infos[func_dec] = t_info\n",
    "\n",
    "    func_decl_instruction = \"The code you write must contain the following functions or classes:\\n\\n\"\n",
    "    for func_name in solution_funcs:\n",
    "        func_decl_instruction += f\"{func_name}\\n\"\n",
    "\n",
    "    instruction = f\"{LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE}\\n\\n\"\n",
    "    for func_name in solution_funcs:\n",
    "        if func_name in relevant_test_infos:\n",
    "            instruction += f\"{relevant_test_infos[func_name]['function_declaration']}\\n\"\n",
    "            doc_string = relevant_test_infos[func_name].get(\"docstring\", \"\")\n",
    "            if doc_string:\n",
    "                indented_docstring = \"\\n\".join(f\"    {line}\" for line in doc_string.split(\"\\n\"))\n",
    "                instruction += f'    \"\"\"\\n{indented_docstring}\\n    \"\"\"\\n\\n'\n",
    "            instruction += \"\\n\"\n",
    "\n",
    "    return func_decl_instruction + \"\\n\\n\" + instruction\n",
    "\n",
    "\n",
    "inst = format_test_info(error_filtered_dataset[1251])\n",
    "print(inst)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dedupe Questions\n",
    "\n",
    "We deduped based on cosine similarity, with a threshold of 0.792"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "from rllm.utils import RAG\n",
    "\n",
    "questions = [entry[\"question\"] for entry in error_filtered_dataset]\n",
    "\n",
    "rag = RAG(docs=questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rag_cutoff = 0.792\n",
    "\n",
    "indices_to_remove = set()\n",
    "\n",
    "for i in tqdm(range(len(error_filtered_dataset))):\n",
    "    if i in indices_to_remove:\n",
    "        continue\n",
    "    similars = rag.top_k(error_filtered_dataset[i][\"question\"], k=5)\n",
    "    for entry in similars[1:]:\n",
    "        if entry[\"score\"].item() > rag_cutoff:\n",
    "            indices_to_remove.add(entry[\"index\"])\n",
    "print(len(indices_to_remove))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deduped_dataset = error_filtered_dataset.filter(lambda x, idx: idx not in indices_to_remove, with_indices=True)\n",
    "\n",
    "print(f\"Original dataset size: {len(error_filtered_dataset)}\")\n",
    "print(f\"Number of indices to remove: {len(indices_to_remove)}\")\n",
    "print(f\"Filtered dataset size: {len(deduped_dataset)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = []\n",
    "for entry in deduped_dataset:\n",
    "    tests = entry[\"test_code\"]\n",
    "    solution_import = \"\\n\".join([line for line in tests.split(\"\\n\") if line.strip().startswith(\"from solution import\")])\n",
    "    tests = \"\\n\".join([line for line in tests.split(\"\\n\") if not line.strip().startswith(\"from solution import\")])\n",
    "\n",
    "    instruction = format_test_info(entry)\n",
    "\n",
    "    problem = f\"\"\"{LCB_SYSTEM_MESSAGE_GENERIC} Make sure your code consists of just standalone classes and functions, which can then be tested in a pytest suite for the correctness of your function/class using assertions on return values.\n",
    "No reading from stdin or writing to stdout is allowed.\n",
    "    \n",
    "{entry[\"question\"].strip()}\n",
    "\n",
    "{instruction}\n",
    "\"\"\"\n",
    "\n",
    "    if len(tests) == 0:\n",
    "        continue\n",
    "    new_entry = {\n",
    "        \"problem\": problem,\n",
    "        \"solutions\": entry[\"solution\"],\n",
    "        \"tests\": tests,\n",
    "    }\n",
    "\n",
    "    dataset.append(new_entry)\n",
    "\n",
    "print(f\"Dataset size: {len(dataset)}\")\n",
    "\n",
    "output_dir = os.path.abspath(\"../../train/code\")\n",
    "output_file = os.path.join(output_dir, \"kodcode.json\")\n",
    "\n",
    "with open(output_file, \"w\") as f:\n",
    "    json.dump(dataset, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
